{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Graph Neural Networks Implementation\n",
    "This notebook implements different types of Graph Neural Networks using PyTorch Geometric (PyG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: torch-geometric in /Users/lakshya/miniconda3/envs/applied-graph-nn-book/lib/python3.12/site-packages (2.6.1)\n",
      "Requirement already satisfied: aiohttp in /Users/lakshya/miniconda3/envs/applied-graph-nn-book/lib/python3.12/site-packages (from torch-geometric) (3.11.8)\n",
      "Requirement already satisfied: fsspec in /Users/lakshya/miniconda3/envs/applied-graph-nn-book/lib/python3.12/site-packages (from torch-geometric) (2024.10.0)\n",
      "Requirement already satisfied: jinja2 in /Users/lakshya/miniconda3/envs/applied-graph-nn-book/lib/python3.12/site-packages (from torch-geometric) (3.1.4)\n",
      "Requirement already satisfied: numpy in /Users/lakshya/miniconda3/envs/applied-graph-nn-book/lib/python3.12/site-packages (from torch-geometric) (1.26.4)\n",
      "Requirement already satisfied: psutil>=5.8.0 in /Users/lakshya/miniconda3/envs/applied-graph-nn-book/lib/python3.12/site-packages (from torch-geometric) (5.9.0)\n",
      "Requirement already satisfied: pyparsing in /Users/lakshya/miniconda3/envs/applied-graph-nn-book/lib/python3.12/site-packages (from torch-geometric) (3.2.0)\n",
      "Requirement already satisfied: requests in /Users/lakshya/miniconda3/envs/applied-graph-nn-book/lib/python3.12/site-packages (from torch-geometric) (2.32.3)\n",
      "Requirement already satisfied: tqdm in /Users/lakshya/miniconda3/envs/applied-graph-nn-book/lib/python3.12/site-packages (from torch-geometric) (4.67.1)\n",
      "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /Users/lakshya/miniconda3/envs/applied-graph-nn-book/lib/python3.12/site-packages (from aiohttp->torch-geometric) (2.4.3)\n",
      "Requirement already satisfied: aiosignal>=1.1.2 in /Users/lakshya/miniconda3/envs/applied-graph-nn-book/lib/python3.12/site-packages (from aiohttp->torch-geometric) (1.3.1)\n",
      "Requirement already satisfied: attrs>=17.3.0 in /Users/lakshya/miniconda3/envs/applied-graph-nn-book/lib/python3.12/site-packages (from aiohttp->torch-geometric) (24.2.0)\n",
      "Requirement already satisfied: frozenlist>=1.1.1 in /Users/lakshya/miniconda3/envs/applied-graph-nn-book/lib/python3.12/site-packages (from aiohttp->torch-geometric) (1.5.0)\n",
      "Requirement already satisfied: multidict<7.0,>=4.5 in /Users/lakshya/miniconda3/envs/applied-graph-nn-book/lib/python3.12/site-packages (from aiohttp->torch-geometric) (6.1.0)\n",
      "Requirement already satisfied: propcache>=0.2.0 in /Users/lakshya/miniconda3/envs/applied-graph-nn-book/lib/python3.12/site-packages (from aiohttp->torch-geometric) (0.2.0)\n",
      "Requirement already satisfied: yarl<2.0,>=1.17.0 in /Users/lakshya/miniconda3/envs/applied-graph-nn-book/lib/python3.12/site-packages (from aiohttp->torch-geometric) (1.18.0)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /Users/lakshya/miniconda3/envs/applied-graph-nn-book/lib/python3.12/site-packages (from jinja2->torch-geometric) (2.1.3)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/lakshya/miniconda3/envs/applied-graph-nn-book/lib/python3.12/site-packages (from requests->torch-geometric) (3.3.2)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /Users/lakshya/miniconda3/envs/applied-graph-nn-book/lib/python3.12/site-packages (from requests->torch-geometric) (3.7)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/lakshya/miniconda3/envs/applied-graph-nn-book/lib/python3.12/site-packages (from requests->torch-geometric) (2.2.3)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /Users/lakshya/miniconda3/envs/applied-graph-nn-book/lib/python3.12/site-packages (from requests->torch-geometric) (2024.8.30)\n"
     ]
    }
   ],
   "source": [
    "# Install required packages if not already installed\n",
    "!pip install torch-geometric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import required libraries\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.datasets import Planetoid\n",
    "from torch_geometric.nn import GCNConv, SAGEConv, GATConv"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and Explore Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index\n",
      "Processing...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of nodes: 2708\n",
      "Number of edges: 10556\n",
      "Number of features: 1433\n",
      "Number of classes: 7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Done!\n"
     ]
    }
   ],
   "source": [
    "# Load the Cora dataset\n",
    "dataset = Planetoid(root='data/Cora', name='Cora')\n",
    "\n",
    "# Get the graph data\n",
    "data = dataset[0]\n",
    "\n",
    "# Print dataset statistics\n",
    "print(f'Number of nodes: {data.num_nodes}')\n",
    "print(f'Number of edges: {data.num_edges}')\n",
    "print(f'Number of features: {data.num_features}')\n",
    "print(f'Number of classes: {dataset.num_classes}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Graph Convolutional Network (GCN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GCN(torch.nn.Module):\n",
    "    def __init__(self, in_channels, hidden_channels, out_channels):\n",
    "        super(GCN, self).__init__()\n",
    "        self.conv1 = GCNConv(in_channels, hidden_channels)\n",
    "        self.conv2 = GCNConv(hidden_channels, out_channels)\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        x = self.conv1(x, edge_index)\n",
    "        x = F.relu(x)\n",
    "        x = F.dropout(x, training=self.training)\n",
    "        x = self.conv2(x, edge_index)\n",
    "        return F.log_softmax(x, dim=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GraphSAGE Implementation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GraphSAGE(torch.nn.Module):\n",
    "    def __init__(self, in_channels, hidden_channels, out_channels):\n",
    "        super(GraphSAGE, self).__init__()\n",
    "        self.conv1 = SAGEConv(in_channels, hidden_channels)\n",
    "        self.conv2 = SAGEConv(hidden_channels, out_channels)\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        x = self.conv1(x, edge_index)\n",
    "        x = F.relu(x)\n",
    "        x = F.dropout(x, training=self.training)\n",
    "        x = self.conv2(x, edge_index)\n",
    "        return F.log_softmax(x, dim=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Graph Attention Network (GAT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GAT(torch.nn.Module):\n",
    "    def __init__(self, in_channels, hidden_channels, out_channels, heads=8, dropout=0.6):\n",
    "        super(GAT, self).__init__()\n",
    "        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout)\n",
    "        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1, dropout=dropout)\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        x = self.conv1(x, edge_index)\n",
    "        x = F.elu(x)\n",
    "        x = F.dropout(x, training=self.training)\n",
    "        x = self.conv2(x, edge_index)\n",
    "        return F.log_softmax(x, dim=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training and Evaluation Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_and_evaluate(model, data, epochs=200):\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)\n",
    "    criterion = torch.nn.NLLLoss()\n",
    "    \n",
    "    # Training\n",
    "    model.train()\n",
    "    for epoch in range(epochs):\n",
    "        optimizer.zero_grad()\n",
    "        out = model(data.x, data.edge_index)\n",
    "        loss = criterion(out[data.train_mask], data.y[data.train_mask])\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        if (epoch + 1) % 50 == 0:\n",
    "            print(f'Epoch {epoch+1:03d}, Loss: {loss.item():.4f}')\n",
    "    \n",
    "    # Evaluation\n",
    "    model.eval()\n",
    "    _, pred = model(data.x, data.edge_index).max(dim=1)\n",
    "    correct = float(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())\n",
    "    accuracy = correct / data.test_mask.sum().item()\n",
    "    print(f'Test Accuracy: {accuracy:.4f}')\n",
    "    \n",
    "    return accuracy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train and Compare Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training GCN...\n",
      "Epoch 050, Loss: 0.0619\n",
      "Epoch 100, Loss: 0.0328\n",
      "Epoch 150, Loss: 0.0300\n",
      "Epoch 200, Loss: 0.0346\n",
      "Test Accuracy: 0.7950\n",
      "\n",
      "Training GraphSAGE...\n",
      "Epoch 050, Loss: 0.0217\n",
      "Epoch 100, Loss: 0.0266\n",
      "Epoch 150, Loss: 0.0175\n",
      "Epoch 200, Loss: 0.0109\n",
      "Test Accuracy: 0.7830\n",
      "\n",
      "Training GAT...\n",
      "Epoch 050, Loss: 0.3255\n",
      "Epoch 100, Loss: 0.3336\n",
      "Epoch 150, Loss: 0.3269\n",
      "Epoch 200, Loss: 0.3211\n",
      "Test Accuracy: 0.7900\n",
      "\n",
      "Final Results:\n",
      "GCN Accuracy: 0.7950\n",
      "GraphSAGE Accuracy: 0.7830\n",
      "GAT Accuracy: 0.7900\n"
     ]
    }
   ],
   "source": [
    "# Model parameters\n",
    "in_channels = dataset.num_node_features\n",
    "hidden_channels = 16\n",
    "out_channels = dataset.num_classes\n",
    "\n",
    "# Initialize models\n",
    "gcn_model = GCN(in_channels, hidden_channels, out_channels)\n",
    "sage_model = GraphSAGE(in_channels, hidden_channels, out_channels)\n",
    "gat_model = GAT(in_channels, hidden_channels, out_channels)\n",
    "\n",
    "# Train and evaluate each model\n",
    "print(\"Training GCN...\")\n",
    "gcn_acc = train_and_evaluate(gcn_model, data)\n",
    "\n",
    "print(\"\\nTraining GraphSAGE...\")\n",
    "sage_acc = train_and_evaluate(sage_model, data)\n",
    "\n",
    "print(\"\\nTraining GAT...\")\n",
    "gat_acc = train_and_evaluate(gat_model, data)\n",
    "\n",
    "# Compare results\n",
    "print(\"\\nFinal Results:\")\n",
    "print(f\"GCN Accuracy: {gcn_acc:.4f}\")\n",
    "print(f\"GraphSAGE Accuracy: {sage_acc:.4f}\")\n",
    "print(f\"GAT Accuracy: {gat_acc:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
